import unittest
from public.read_json import read_json
from epics import caget,caput
from public._snmp import snmpget
from re import sub
from public.logger import getLogger
from random import randint
import time
import csv
from os import listdir,mkdir
class file_opr():
    """Keep per-PV pass/fail statistics in CSV files under ``./output``.

    ``Test_width.csv``, ``Test_delay.csv`` and ``Test_enable.csv`` each hold a
    header plus one row per PV (``<pv>_W``, ``<pv>_D``, ``<pv>_Enable``) with
    test count, success count, success rate, accumulated and average execution
    time. They are loaded into memory on construction, updated by
    :meth:`write_data` and flushed back by :meth:`write_csv`.
    ``Test_freqence.csv`` is an append-only log written by
    :meth:`write_freqence_csv`.
    """

    # Header shared by the width/delay/enable statistics files.
    _COUNTER_FIELDS = ("PV","测试次数","成功次数","成功率","累计执行时间（s）","平均执行时间（s）")
    # write_data() ``type`` value -> PV-name suffix used in the statistics rows.
    _SUFFIXES = {"width": "_W", "delay": "_D", "enable": "_Enable"}

    def __init__(self,pv_list):
        """Create any missing output files, then load the three statistics tables.

        :param pv_list: base PV names; used to seed freshly created files.
        """
        self.pv_list = pv_list
        if not "output" in listdir("."):
            mkdir("output")
        self._ensure_counter_csv("Test_width.csv", self._SUFFIXES["width"])
        self._ensure_counter_csv("Test_delay.csv", self._SUFFIXES["delay"])
        self._ensure_counter_csv("Test_enable.csv", self._SUFFIXES["enable"])
        if not "Test_freqence.csv" in listdir("./output"):
            with open("./output/Test_freqence.csv","w",newline="",encoding="utf-8") as fp:
                fields = ["下发值","回读计算值","是否成功", "累计执行时间（s）","平均执行时间（s）"]
                csv.DictWriter(fp,fields).writeheader()
        # Rows are lists of strings as read back by csv.reader (header at index 0).
        self.csv_enable_data = self._read_rows("Test_enable.csv")
        self.csv_delay_data = self._read_rows("Test_delay.csv")
        self.csv_width_data = self._read_rows("Test_width.csv")

    @staticmethod
    def _new_row(pvname):
        """Return a zeroed statistics row for *pvname*."""
        return [pvname, 0, 0, "100%", 0.0, 0.0]

    def _ensure_counter_csv(self, filename, suffix):
        """Create ./output/<filename> with one zeroed row per PV if it is missing."""
        if filename in listdir("./output"):
            return
        with open("./output/" + filename,"w",newline="",encoding="utf-8") as fp:
            writer = csv.writer(fp)
            writer.writerow(self._COUNTER_FIELDS)
            for name in self.pv_list:
                writer.writerow(self._new_row(name + suffix))

    @staticmethod
    def _read_rows(filename):
        """Return every row of ./output/<filename>, header included."""
        with open("./output/" + filename,"r",newline="",encoding="utf-8") as fp:
            return list(csv.reader(fp))

    @staticmethod
    def _write_rows(filename, rows):
        """Overwrite ./output/<filename> with *rows*."""
        with open("./output/" + filename,"w",newline="",encoding="utf-8") as fp:
            csv.writer(fp).writerows(rows)

    @classmethod
    def _find_row(cls, rows, pvname):
        """Return the row whose PV column is *pvname*, appending one if absent.

        Looking the row up by name (instead of by position in ``pv_list``)
        keeps statistics correct when the CSV was created for a different PV
        list. Rows from older files that lack the timing columns are padded.
        """
        for row in rows[1:]:
            if row and row[0] == pvname:
                break
        else:
            row = cls._new_row(pvname)
            rows.append(row)
        while len(row) < len(cls._COUNTER_FIELDS):
            row.append(0.0)
        return row

    def write_data(self,type=None,name = None,flag=None, exec_time=None):
        """Record one test result in the in-memory statistics table.

        :param type: "enable", "delay" or "width"; any other value is ignored.
        :param name: base PV name (the suffix for *type* is appended here).
        :param flag: truthy when the readback matched the written value.
        :param exec_time: execution time in seconds, or None to leave the
            timing columns untouched.
        """
        tables = {
            "enable": self.csv_enable_data,
            "delay": self.csv_delay_data,
            "width": self.csv_width_data,
        }
        rows = tables.get(type)
        if rows is None:
            return
        row = self._find_row(rows, name + self._SUFFIXES[type])
        row[1] = int(row[1]) + 1
        if flag:
            row[2] = int(row[2]) + 1
        row[3] = "{:.2f}%".format(float(row[2]) / float(row[1]) * 100)
        if exec_time is not None:
            # Accumulated time, then the average over all tests of this PV.
            new_total_time = float(row[4]) + exec_time
            row[4] = new_total_time
            row[5] = "{:.4f}".format(new_total_time / int(row[1]))

    def write_csv(self):
        """Flush the three in-memory statistics tables back to disk."""
        self._write_rows("Test_enable.csv", self.csv_enable_data)
        self._write_rows("Test_width.csv", self.csv_width_data)
        self._write_rows("Test_delay.csv", self.csv_delay_data)

    def write_freqence_csv(self,send_value,read_value,flag,exec_time=None):
        """Append one frequency-check row: sent value, computed readback, result, time."""
        # NOTE(review): rows carry 4 values while the header created in
        # __init__ lists 5 columns; the time lands under "累计执行时间（s）".
        # Confirm the intended layout.
        with open("./output/Test_freqence.csv","a+",newline="",encoding="utf-8") as fp:
            time_text = "{:.4f}".format(exec_time) if exec_time is not None else "N/A"
            csv.writer(fp).writerow([send_value,read_value,flag,time_text])
            
        
class snmp_get_data():
    """Read IMP settings back from the hardware through ``snmpget``.

    All lookup dicts are keyed by base PV name: ``host_dict`` gives the SNMP
    host, ``index_dict`` the index substituted into the enable OID, and
    ``name_dict`` the device name. "Node" selects the node MIB file; any other
    name selects the multi-channel MIB file. Every reader returns None when
    the reply contains "Error" or carries no digits after the expected SNMP
    type tag.
    """

    def __init__(self,host_dict,index_dict,name_dict) -> None:
        self.host_dict = host_dict
        self.index_dict = index_dict
        self.name_dict = name_dict

    def _mib_file(self, pvname):
        """Return the MIB file path matching the device behind *pvname*."""
        if self.name_dict[pvname] == "Node":
            return "./WR-IMP-MIB-node.txt"
        return "./WR-IMP-MIB-multi.txt"

    @staticmethod
    def _parse_digits(return_data, type_tag):
        """Return the digit characters that follow *type_tag* in an snmpget reply.

        The reply is split on the tag, the piece after the first occurrence is
        kept, and every non-digit is removed. Returns None when the reply
        contains "Error", lacks the tag, or has no digits after it.
        """
        if "Error" in return_data:
            return None
        pieces = return_data.split(type_tag)
        if len(pieces) < 2:
            return None
        digits = sub(r"[^\d]", "", pieces[1])
        return digits or None

    def _query(self, pvname, oid, type_tag):
        """Run snmpget for *oid* against *pvname*'s host; return its digits or None."""
        return_data = snmpget(host=self.host_dict[pvname], file_path=self._mib_file(pvname),
                              OIDname=oid)
        return self._parse_digits(return_data, type_tag)

    def enable(self, pvname):
        """Return the enable value (INTEGER) as an int, or None on failure."""
        digits = self._query(pvname, "impCtrlMC{}Ena".format(self.index_dict[pvname]), "INTEGER:")
        return None if digits is None else int(digits)

    def width(self, pvname):
        """Return the raw impWavePatternWidth.0 Gauge32 value divided by 250."""
        # NOTE(review): /250 presumably converts 4 ns ticks to microseconds; confirm.
        digits = self._query(pvname, "impWavePatternWidth.0", "Gauge32:")
        return None if digits is None else float(digits) / 250

    def delay(self, pvname):
        """Return the raw impWavePatternDelay.0 Gauge32 value divided by 250."""
        digits = self._query(pvname, "impWavePatternDelay.0", "Gauge32:")
        return None if digits is None else float(digits) / 250

    def idle_4ns(self, pvname):
        """Return impWavePatternIdle4ns.0 as an int, or None on failure."""
        digits = self._query(pvname, "impWavePatternIdle4ns.0", "Gauge32:")
        return None if digits is None else int(digits)

    def idle_100ms(self, pvname):
        """Return impWavePatternIdle100ms.0 as an int, or None on failure."""
        digits = self._query(pvname, "impWavePatternIdle100ms.0", "Gauge32:")
        return None if digits is None else int(digits)

    def get_all_date(self,pvname):
        """Return (width, delay, enable, idle_4ns, idle_100ms) for *pvname*."""
        # idle_100ms must be called; returning the bound method was a bug.
        return (self.width(pvname), self.delay(pvname), self.enable(pvname),
                self.idle_4ns(pvname), self.idle_100ms(pvname))

    def get_all_width(self):
        """Return {pvname: width} for every known PV."""
        return {key: self.width(key) for key in self.host_dict}

    def get_all_delay(self):
        """Return {pvname: delay} for every known PV."""
        return {key: self.delay(key) for key in self.host_dict}

    def get_all_enable(self):
        """Return {pvname: enable} for every known PV."""
        return {key: self.enable(key) for key in self.host_dict}

    def get_all_4ns(self):
        """Return {pvname: idle_4ns} for every known PV."""
        return {key: self.idle_4ns(key) for key in self.host_dict}

    def get_all_100ms(self):
        """Return {pvname: idle_100ms} for every known PV."""
        return {key: self.idle_100ms(key) for key in self.host_dict}

class Test(unittest.TestCase):
    """Hardware checks: write EPICS PVs, then verify the values over SNMP.

    The check methods (``width``, ``delay``, ``enable``, ``freqence``) do not
    use the ``test_`` prefix; the ``__main__`` block builds the suite from
    their names explicitly. Per-PV results are accumulated in ``file_opr``
    and flushed to CSV in ``tearDownClass``.
    """

    # Seconds to keep polling the frequency PV after the initial 10 s settle.
    FRQ_READBACK_TIMEOUT = 60

    @classmethod
    def setUpClass(cls) -> None:
        """Build the PV lookup tables from read_json() and create the helpers."""
        cls.pv_list = []
        cls.host_dict = {}
        cls.index_dict = {}
        cls.name_dict = {}
        cls.FRQ = 'TIMS:WR:Frq_Set'
        cls.mulit_data,cls.node_data = read_json()
        # Multi-channel devices: each entry provides "host", "name" and a
        # "data" mapping of base PV name -> index used in the enable OID.
        for device in cls.mulit_data:
            for key, value in device["data"].items():
                pv = '{}'.format(key)
                cls.host_dict[pv] = device["host"]
                cls.index_dict[pv] = value
                cls.name_dict[pv] = device["name"]
                cls.pv_list.append(key)
        # Nodes: base PV name -> SNMP host; they always use index 0.
        for key, value in cls.node_data.items():
            pv = '{}'.format(key)
            cls.host_dict[pv] = value
            cls.index_dict[pv] = 0
            cls.name_dict[pv] = "Node"
            cls.pv_list.append(key)
        cls.file_opr = file_opr(cls.pv_list)
        cls.snmp_opr = snmp_get_data(host_dict=cls.host_dict,index_dict=cls.index_dict,name_dict=cls.name_dict)
        return super().setUpClass()

    @classmethod
    def tearDownClass(cls) -> None:
        """Persist the accumulated statistics."""
        cls.file_opr.write_csv()
        return super().tearDownClass()

    def setUp(self) -> None:
        self.start_tiem = time.perf_counter()
        return super().setUp()

    def tearDown(self) -> None:
        # self.logger is assigned by each check method on its first line.
        _time = time.perf_counter() - self.start_tiem
        self.logger.info("测试用例运行时间:{}".format(_time))
        return super().tearDown()

    @staticmethod
    def _gauge_value(line):
        """Return the integer that follows "Gauge32:" in one snmpget output line."""
        return int(sub(r"[^\d]", "", line.split("Gauge32:")[1]))

    @staticmethod
    def _wait_for_value(pvname, target, timeout, tol=1e-6):
        """Poll caget(pvname) once per second until it is within *tol* of *target*.

        Returns True on success, False once *timeout* seconds have passed.
        The tolerance avoids spinning forever on float readback rounding.
        """
        deadline = time.perf_counter() + timeout
        while True:
            value = caget(pvname=pvname)
            if value is not None and abs(value - target) <= tol:
                return True
            if time.perf_counter() >= deadline:
                return False
            time.sleep(1)

    def _put_and_verify(self, kind, name, pvname, value, reader):
        """caput *value* to *pvname*, read it back with *reader*, record the result.

        :param kind: statistics table passed to file_opr.write_data.
        :param name: base PV name passed to *reader* and write_data.
        :param reader: snmp_get_data method returning the hardware value.
        """
        flag = caput(pvname,value)
        start_time = time.perf_counter()
        if not flag:
            self.logger.info("数据下发失败")
            return
        self.logger.info("数据下发成功")
        time.sleep(0.3)
        readback = reader(name)
        self.logger.info("snmp数据:{}\t PV下发数据:{}".format(readback,value))
        exec_time = time.perf_counter() - start_time
        matched = readback == value
        self.logger.info("开始对比数据:{}\n".format(matched))
        self.file_opr.write_data(type=kind,name=name,flag=matched,exec_time=exec_time)

    def width(self):
        """Write random widths to every ``<pv>_W`` PV and compare with SNMP.

        Node PVs are exercised 10 times, other PVs once. Any exception is
        logged and that iteration is skipped without recording a result.
        """
        self.logger = getLogger("Test_width")
        self.logger.info("-----开始测试脉宽-----")
        for name in self.pv_list:
            pvname = name + "_W"
            fornum = 10 if self.name_dict[name] == "Node" else 1
            for _ in range(fornum):
                try:
                    pv_value = caget(pvname=pvname)
                    self.logger.info("当前PV:{}\t 当前值:{}".format(pvname,pv_value))
                    randnum = randint(1,100000)
                    self.logger.info("开始生成随机数\t 随机数值:{}".format(randnum))
                    self._put_and_verify("width", name, pvname, randnum, self.snmp_opr.width)
                except Exception as e:
                    self.logger.error("当前PV:{}\t 链接失败: {}".format(pvname,str(e)))

    def delay(self):
        """Write random delays to every ``<pv>_D`` PV and compare with SNMP.

        Node PVs are exercised 10 times, other PVs once. Unlike ``width``,
        exceptions are not caught and abort the method.
        """
        self.logger = getLogger("Test_delay")
        self.logger.info("-----开始测试延时-----")
        for name in self.pv_list:
            pvname = name + "_D"
            fornum = 10 if self.name_dict[name] == "Node" else 1
            for _ in range(fornum):
                pv_value = caget(pvname=pvname)
                self.logger.info("当前PV:{}\t 当前值:{}".format(pvname,pv_value))
                randnum = randint(0,100000)
                self.logger.info("开始生成随机数\t 随机数值:{}".format(randnum))
                self._put_and_verify("delay", name, pvname, randnum, self.snmp_opr.delay)

    def enable(self):
        """Toggle every ``<pv>_Enable`` PV (0 <-> 1) and compare with SNMP.

        Node PVs are exercised 10 times, other PVs once. A current value other
        than 0/1 (e.g. a failed caget) is logged and the iteration is skipped.
        Exceptions are not caught.
        """
        self.logger = getLogger("Test_enable")
        self.logger.info("-----开始测试使能-----")
        for name in self.pv_list:
            pvname = name + "_Enable"
            fornum = 10 if self.name_dict[name] == "Node" else 1
            for _ in range(fornum):
                pv_value = caget(pvname=pvname)
                self.logger.info("当前PV:{}\t 当前值:{}".format(pvname,pv_value))
                if pv_value == 0:
                    randnum = 1
                elif pv_value == 1:
                    randnum = 0
                else:
                    # Without this guard randnum would be unbound or stale.
                    self.logger.error("当前PV:{}\t 当前值无效，跳过".format(pvname))
                    continue
                self._put_and_verify("enable", name, pvname, randnum, self.snmp_opr.enable)

    def freqence(self):
        """Set a random frequency and check it against the wave-pattern tables.

        After the PV reads back the requested value, the delay/width/idle
        tables of every multi device and node are read via SNMP. A frequency
        is recomputed from them and compared with the PV readback within
        0.0001. The result is appended to Test_freqence.csv.
        """
        self.logger = getLogger("Test_freqence")
        self.logger.info("-----开始测试频率-----")
        pvname = self.FRQ
        pv_value = caget(pvname=pvname)
        self.logger.info("当前频率值:{}".format(pv_value))
        frq_list = [1,2,3,4,5,6,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.01]
        randnum = frq_list[randint(0, len(frq_list) - 1)]
        self.logger.info("开始生成随机频率[1,20]\t 随机频率为:{}".format(randnum))
        flag = caput(pvname,randnum)
        start_time = time.perf_counter()
        if not flag:
            self.logger.info("数据下发失败")
            return
        time.sleep(10)
        if self._wait_for_value(pvname, randnum, self.FRQ_READBACK_TIMEOUT):
            self.logger.info("频率下发成功")
        else:
            self.logger.error("频率回读超时，设定值:{}".format(randnum))
        delay, width, idle_4ns, idle_100ms = [], [], [], []
        file_path = "./WR-IMP-MIB-multi.txt"
        for data in self.mulit_data:
            return_data = snmpget(OIDname="impWavePatternTable1", file_path=file_path, host=data["host"])
            if "Error" in return_data:
                continue
            try:
                lines = return_data.split("\n")
                # Reply lines 0-9, 10-19, 20-29 and 30-39 hold the delay,
                # width, idle_4ns and idle_100ms columns respectively.
                for target, start in ((delay, 0), (width, 10), (idle_4ns, 20), (idle_100ms, 30)):
                    for line in lines[start:start + 10]:
                        target.append(self._gauge_value(line))
            except Exception as e:
                self.logger.error("出现异常：{}".format(e))
        for host in self.node_data.values():
            # NOTE(review): nodes are queried with the multi MIB file, as the
            # original code did through a leftover variable, whereas
            # snmp_get_data uses ./WR-IMP-MIB-node.txt for nodes. Confirm.
            return_data = snmpget(OIDname="impWavePatternTable1", file_path=file_path, host=host)
            if "Error" in return_data:
                continue
            try:
                lines = return_data.split("\n")
                # Lines 0, 16, 32, 48: presumably the first row of each of the
                # delay/width/idle_4ns/idle_100ms columns; confirm.
                for column, target in enumerate((delay, width, idle_4ns, idle_100ms)):
                    target.append(self._gauge_value(lines[16 * column]))
            except Exception as e:
                self.logger.error("出现异常：{}".format(e))
        exec_time = time.perf_counter() - start_time
        pv_value = caget(pvname=pvname)
        # delay/width/idle_4ns are scaled by 4e-9 s and idle_100ms by 0.1 s.
        sum_time = 4*sum(delay+width+idle_4ns)*10**-9+sum(idle_100ms)*0.1
        calc_value = len(width+idle_4ns+idle_100ms)/sum_time/3 if sum_time else None
        # abs() is required: without it every over-estimate counted as a pass.
        matched = (pv_value is not None and calc_value is not None
                   and abs(pv_value - calc_value) <= 0.0001)
        if matched:
            self.logger.info("频率下发成功，设定值:{},机箱回读计算结果为:{}".format(randnum,calc_value))
        else:
            self.logger.info("频率下发失败，设定值:{},机箱回读计算结果为:{}".format(randnum,calc_value))
        self.file_opr.write_freqence_csv(randnum,calc_value,flag=matched,exec_time=exec_time)


if __name__ == "__main__":
    # Soak test: run the whole check suite over and over until interrupted,
    # numbering each round. Every round gets a fresh suite and runner.
    case_names = ("width", "delay", "enable", "freqence")
    round_no = 1
    while True:
        print("第{}次测试开始".format(round_no))
        suite = unittest.TestSuite()
        suite.addTests(tests=[Test(case_name) for case_name in case_names])
        unittest.TextTestRunner().run(suite)
        round_no += 1
